outline for loading old sleap config #162
base: main
Conversation
Walkthrough: This change introduces four new functions across different configuration modules for loading a legacy SLEAP configuration.
Actionable comments posted: 4
🧹 Nitpick comments (3)
sleap_nn/config/training_job_config.py (3)
154-162: Missing docstring for new function

Unlike other functions in this file, load_sleap_config lacks a docstring describing its purpose, parameters, and return value. Add a descriptive docstring:

 def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+    """Load a training job configuration from a legacy SLEAP JSON config file.
+
+    Arguments:
+        json_file_path: Path to a legacy SLEAP JSON configuration file.
+
+    Returns:
+        A TrainingJobConfig instance containing the mapped configuration data.
+    """
     with open(json_file_path, "r") as f:
         old_config = json.load(f)
158-161: Missing error handling for mapper functions

There's no error handling for potential exceptions in the mapper functions. Add error handling for the mapper functions:

-    data_config = data_mapper(old_config)
-    model_config = model_mapper(old_config)
-    trainer_config = trainer_mapper(old_config)
+    try:
+        data_config = data_mapper(old_config)
+        model_config = model_mapper(old_config)
+        trainer_config = trainer_mapper(old_config)
+    except Exception as e:
+        raise ValueError(f"Failed to map configuration from {json_file_path}: {str(e)}")
162-162: Filename attribute not set in the created config

Unlike the YAML loading methods, this function doesn't set the filename attribute to the source file path. Set the filename attribute in the returned config:

-    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
+    return cls(
+        data_config=data_config,
+        model_config=model_config,
+        trainer_config=trainer_config,
+        filename=json_file_path
+    )
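Taken together, the three nitpicks above could be folded into one loader. The following is a minimal sketch, not the PR's final implementation: it assumes the mapper functions are importable from the modules touched in this PR and that TrainingJobConfig accepts a filename argument as the YAML loaders imply.

import json

from sleap_nn.config.data_config import data_mapper
from sleap_nn.config.model_config import model_mapper
from sleap_nn.config.trainer_config import trainer_mapper


def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
    """Load a training job configuration from a legacy SLEAP JSON config file."""
    try:
        with open(json_file_path, "r") as f:
            old_config = json.load(f)
    except json.JSONDecodeError:
        raise ValueError(f"Invalid JSON in configuration file: {json_file_path}")

    try:
        data_config = data_mapper(old_config)
        model_config = model_mapper(old_config)
        trainer_config = trainer_mapper(old_config)
    except Exception as e:
        raise ValueError(f"Failed to map configuration from {json_file_path}: {e}")

    return cls(
        data_config=data_config,
        model_config=model_config,
        trainer_config=trainer_config,
        filename=json_file_path,
    )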
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- sleap_nn/config/training_job_config.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
    with open(json_file_path, "r") as f:
        old_config = json.load(f)
🛠️ Refactor suggestion
Missing error handling for file operations and JSON parsing
The function doesn't handle potential exceptions when opening the file or parsing the JSON content.
Add error handling for file operations and JSON parsing:
- with open(json_file_path, "r") as f:
- old_config = json.load(f)
+ try:
+ with open(json_file_path, "r") as f:
+ old_config = json.load(f)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found: {json_file_path}")
+ except json.JSONDecodeError:
+ raise ValueError(f"Invalid JSON in configuration file: {json_file_path}")
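For illustration, with that error handling in place a caller can distinguish a missing file from malformed JSON. This is a hypothetical usage sketch; the path "missing_config.json" is made up, and the call pattern mirrors how the tests in this PR invoke load_sleap_config:

from sleap_nn.config.training_job_config import TrainingJobConfig, load_sleap_config

try:
    # "missing_config.json" is a placeholder path used only for this example.
    config = load_sleap_config(TrainingJobConfig, "missing_config.json")
except FileNotFoundError as err:
    print(f"Config not found: {err}")
except ValueError as err:
    print(f"Config could not be parsed: {err}")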
    model_config = model_mapper(old_config)
    trainer_config = trainer_mapper(old_config)

    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
💡 Verification agent
🧩 Analysis chain
Missing validation for loaded configuration

The function doesn't apply the check_output_strides validation that's performed in the YAML loading methods. Consider whether the check_output_strides validation should be applied to JSON-loaded configurations as well. This validation ensures compatibility between backbone and head configurations.
🏁 Script executed:
#!/bin/bash
# Check if check_output_strides is applied in any other loaders
grep -rn "check_output_strides" --include="*.py" ./
Length of output: 379
Action Required: Apply Missing check_output_strides Validation in JSON Loading

It appears that the JSON configuration loading path in sleap_nn/config/training_job_config.py (culminating in the return statement at line 162) does not invoke the check_output_strides validation. In contrast, YAML loaders in this file call TrainingJobConfig.check_output_strides(config) (as seen on lines 99 and 117). For consistency and to ensure compatibility between backbone and head configurations, please consider applying the same validation for JSON-loaded configurations.

- Evaluate adding a call to TrainingJobConfig.check_output_strides(config) right before the instance initialization in the JSON loader (see the sketch below).
- Confirm that this adjustment aligns with the intended configuration compatibility checks across all loaders.
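As a hedged illustration of where that call could sit at the end of the JSON loader (mirroring the YAML loaders; the exact call pattern should follow lines 99 and 117, which are not shown here):

-    return cls(data_config=data_config, model_config=model_config, trainer_config=trainer_config)
+    config = cls(
+        data_config=data_config,
+        model_config=model_config,
+        trainer_config=trainer_config,
+    )
+    # Mirror the YAML loaders: validate backbone/head output stride compatibility
+    # before handing the config back to the caller.
+    TrainingJobConfig.check_output_strides(config)
+    return config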
@@ -146,3 +150,13 @@
         The parsed `OmegaConf`.
     """
     return TrainingJobConfig.load_yaml(filename)


def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
🛠️ Refactor suggestion
Function lacks classmethod decorator despite using cls parameter

This function appears to be intended as a class method since it takes cls as its first parameter, but it's missing the @classmethod decorator and is defined at module level rather than within the TrainingJobConfig class.
Consider moving this function into the TrainingJobConfig class and adding the @classmethod decorator:
-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+@classmethod
+def load_sleap_config(cls, json_file_path: str) -> "TrainingJobConfig":
Committable suggestion skipped: line range outside the PR's diff.
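If the function is moved onto the class as suggested, call sites could then use it in constructor style. A hypothetical usage sketch (the file name is a placeholder, not a file from this PR):

from sleap_nn.config.training_job_config import TrainingJobConfig

# "legacy_training_config.json" is a hypothetical path used only for illustration.
config = TrainingJobConfig.load_sleap_config("legacy_training_config.json")
print(type(config.trainer_config))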
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- sleap_nn/config/data_config.py (1 hunks)
- sleap_nn/config/model_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (1)
sleap_nn/config/data_config.py (1)
215-221: ✅ Verification successful

Verify the probability values for augmentation configuration

The probability values (uniform_noise_p, gaussian_noise_p, contrast_p) are set to default values of 1.0, but the base class IntensityConfig defines them with defaults of 0.0. Additionally, validation for these fields ensures they are between 0.0 and 1.0. Verify whether the default values of 1.0 from the legacy configuration are correct. This seems inconsistent with the class definitions, where these probabilities default to 0.0, suggesting they might not be applied by default.
🏁 Script executed:
#!/bin/bash
# Check if there's any precedent in the codebase for using 1.0 as default probability
rg -A 2 -B 2 "noise_p|contrast_p|brightness_p"

Length of output: 32418
Legacy augmentation probability defaults confirmed

After reviewing the legacy configuration across the codebase and corresponding tests, it appears that the default values of 1.0 for uniform_noise_p, gaussian_noise_p, and contrast_p in the legacy configuration are intentional. Although the base class (IntensityConfig) defaults these probabilities to 0.0 and enforces a [0.0, 1.0] range, the legacy branch explicitly overrides them to 1.0, as validated by tests in files like tests/test_train.py and tests/data/test_augmentation.py. To avoid future confusion, consider documenting this intentional behavior and noting the divergence from the new defaults.
Actionable comments posted: 1
♻️ Duplicate comments (1)
sleap_nn/config/data_config.py (1)
189-243: ⚠️ Potential issue: Inconsistency in parameter handling

The function implementation has several issues that need to be addressed:

- Many parameters are commented out (lines 191-201), including critical path parameters like train_labels_path and val_labels_path, which are marked as MISSING in the DataConfig class.
- Line 210 comments out use_augmentations_train, but it's used in the conditional at line 241.
- The structure of the legacy_config appears inconsistent, with some values coming from data.preprocessing and others from optimization.augmentation_config.

 def data_mapper(legacy_config: dict) -> DataConfig:
     return DataConfig(
-        # train_labels_path=legacy_config.get("train_labels_path", MISSING),
-        # val_labels_path=legacy_config.get("val_labels_path", MISSING),
+        train_labels_path=legacy_config.get("train_labels_path", MISSING),
+        val_labels_path=legacy_config.get("val_labels_path", MISSING),
         # test_file_path=legacy_config.get("test_file_path"),
         # provider=legacy_config.get("provider", "LabelsReader"),
         # user_instances_only=legacy_config.get("user_instances_only", True),
         # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"),
         # np_chunks_path=legacy_config.get("np_chunks_path"),
         # litdata_chunks_path=legacy_config.get("litdata_chunks_path"),
         # use_existing_chunks=legacy_config.get("use_existing_chunks", False),
         # chunk_size=int(legacy_config.get("chunk_size", 100)),
         # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True),
         preprocessing=PreprocessingConfig(
             is_rgb=legacy_config.get("data", {}).get("preprocessing", {}).get("ensure_rgb", False),
             max_height=legacy_config.get("data", {}).get("preprocessing", {}).get("target_height"),
             max_width=legacy_config.get("data", {}).get("preprocessing", {}).get("target_width"),
             scale=legacy_config.get("data", {}).get("preprocessing", {}).get("input_scaling", 1.0),
             crop_hw=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size"),
             min_crop_size=legacy_config.get("data", {}).get("preprocessing", {}).get("crop_size_detection_padding", 100),
         ),
-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
         augmentation_config=AugmentationConfig(
             intensity=IntensityConfig(
                 uniform_noise_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("uniform_noise_min_val", 0.0),
🧹 Nitpick comments (1)
sleap_nn/config/data_config.py (1)
189-190: Add docstring to data_mapper function

The data_mapper function lacks a docstring to explain its purpose, parameters, and return value. This is important for maintainability and understanding the code's intent.

+"""
+Convert a legacy SLEAP configuration dictionary to a DataConfig object.
+
+Args:
+    legacy_config: Dictionary containing configuration from older SLEAP version
+
+Returns:
+    DataConfig: A structured configuration object matching the legacy settings
+"""
 def data_mapper(legacy_config: dict) -> DataConfig:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- sleap_nn/config/data_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (2)
sleap_nn/config/data_config.py (2)
222-224: Fixed brightness tuple creation

The brightness tuple is now correctly created using parentheses, fixing the syntax error from the previous review.
213-221: ❓ Verification inconclusive

Inconsistent default values for noise and contrast parameters

The default values for noise and contrast parameters don't match the class defaults. The class defaults are:

- uniform_noise_p: 0.0 in class, 1.0 in mapper
- gaussian_noise_p: 0.0 in class, 1.0 in mapper
- contrast_p: 0.0 in class, 1.0 in mapper

Please verify if these different default values are intentional or if they should match the class defaults.
🏁 Script executed:
#!/bin/bash
# Check where these values are used elsewhere in the codebase
echo "Searching for uniform_noise, gaussian_noise, and contrast default usages:"
rg --type py "uniform_noise|gaussian_noise|contrast.*p" -A 2 -B 2

Length of output: 34468
Action Required: Verify Default Augmentation Parameter Values

The legacy configuration mapper in sleap_nn/config/data_config.py is currently defaulting the following parameters to 1.0:

- uniform_noise_p
- gaussian_noise_p
- contrast_p

However, the corresponding defaults defined in the augmentation module (sleap_nn/data/augmentation.py) are 0.0 for these parameters. This discrepancy may lead to inconsistent augmentation behavior depending on whether values come from the legacy config or are used directly from the class defaults. Some tests (e.g., in tests/test_train.py and tests/data/test_augmentation.py) appear to expect a default of 1.0 when augmentations are explicitly enabled, while the augmentation module itself defines a "disabled" default of 0.0.

Please verify whether the legacy mapping defaults are intended to override the class defaults. If the intended behavior is to have augmentations off by default, update the mapper in sleap_nn/config/data_config.py to use 0.0 (matching the class defaults). Otherwise, if the legacy behavior (defaulting to 1.0) is deliberate, ensure that the discrepancies are well documented and that the tests reflect this intended behavior.
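If the 1.0 defaults are kept, one lightweight way to document the divergence is a comment next to the intensity block in the mapper. This is only a sketch; placement and indentation are approximate, and the comment wording is illustrative rather than part of the PR:

         augmentation_config=AugmentationConfig(
             intensity=IntensityConfig(
+                # NOTE: unlike IntensityConfig's own defaults (0.0 = disabled),
+                # the legacy mapper intentionally defaults these probabilities
+                # to 1.0, i.e. enabled whenever legacy augmentations are mapped.
                 uniform_noise_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("uniform_noise_min_val", 0.0),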
sleap_nn/config/data_config.py (Outdated)
            geometric=GeometricConfig(
                rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
                scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
                       legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
                # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
                # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
                # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0),
                # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001),
                # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01),
                # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0),
                # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0),
                # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0),
                # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]),
                # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0),
            )
🛠️ Refactor suggestion
Potential issues with geometric configuration parameters
There are several issues with the geometric configuration section:
- The scale parameter uses None as potential values, which might conflict with the validator in the GeometricConfig class.
- Most parameters are commented out without a clear indication why.
- The rotation parameter uses a different default (180.0) compared to the class default (15.0).
geometric=GeometricConfig(
rotation=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("rotation_max_angle", 180.0),
- scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
- legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None)),
+ scale=(legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+ legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1)),
# translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2),
# translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2),
Consider adding a comment to explain why most of these parameters are commented out, or uncomment them if they should be included.
Actionable comments posted: 2
♻️ Duplicate comments (2)
sleap_nn/config/model_config.py (1)
906-917: ⚠️ Potential issue: Always creating centroid head violates one-of semantics.

Because centroid is set unconditionally, it will be active alongside any other head, contradicting the @oneof policy that only one head can be non-None. Wrap this in a condition to ensure that we only create the centroid head if head_type is set to "centroid".

-        centroid=CentroidConfig(
+        centroid=(
+            CentroidConfig(
+                confmaps=CentroidConfMapsConfig(
+                    anchor_part=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("anchor_part"),
+                    sigma=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("sigma", 5.0),
+                    output_stride=legacy_config.get("heads", {})
+                    .get("centroid", {})
+                    .get("output_stride", 1),
+                )
+            )
+            if legacy_config.get("head_type") == "centroid"
+            else None
+        ),

sleap_nn/config/data_config.py (1)
270-277: ⚠️ Potential issue: Scale set to (None, None) may break validation logic.

By default, GeometricConfig.scale expects numeric tuples (e.g., (0.9, 1.1)). Using (None, None) can trigger type errors or produce unintended outcomes. Provide valid numeric values if you want to enable scaling, or set this to None if scaling is not used.

-            scale=(
-                legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", None),
-                legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", None),
-            ),
+            scale=(
+                legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_min", 0.9),
+                legacy_config.get("optimization", {}).get("augmentation_config", {}).get("scale_max", 1.1),
+            ),
🧹 Nitpick comments (2)
sleap_nn/config/model_config.py (1)
835-836: Remove or clarify leftover comment.

These lines appear to contain a partial or uncertain comment about pretrained_backbone_weights (e.g., )?? # i think its different). This can be confusing to future maintainers. Consider removing or clarifying these comments to ensure the code is clean and unambiguous.

-    # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")?? # i think its different
+    # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")  # Clarify if needed or remove entirely

sleap_nn/config/data_config.py (1)
223-224: use_augmentations_train is commented out, causing mismatch.

The DataConfig instance sets use_augmentations_train to its default False, even if legacy_config indicates otherwise. Uncomment and map use_augmentations_train from legacy_config to keep the final config in sync.

-        # use_augmentations_train=legacy_config.get("use_augmentations_train", False),
+        use_augmentations_train=legacy_config.get("use_augmentations_train", False),
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- sleap_nn/config/data_config.py (1 hunks)
- sleap_nn/config/model_config.py (1 hunks)
- sleap_nn/config/training_job_config.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- sleap_nn/config/training_job_config.py
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (4)
sleap_nn/config/model_config.py (2)
837-842: Confirm usage of commented-out parameters.

In the returned ModelConfig, multiple parameters like init_weights and pre_trained_weights are commented out (lines 839-841). If these are intentionally omitted from legacy handling, consider removing them altogether. Otherwise, uncomment and map them from legacy_config to ensure backward compatibility.

889-905: Conditional creation for single_instance head looks good.

The code correctly creates a SingleInstanceConfig only when legacy_config.get("head_type") == "single_instance". This aligns with the @oneof requirement that only one head should be active.

sleap_nn/config/data_config.py (2)
190-202: Review whether commented-out fields should be restored or removed.

Several fields (e.g., train_labels_path, val_labels_path, etc.) are commented out, preventing them from being set in DataConfig. Verify if you need these fields for accurate data loading. If they are essential, uncomment and map them; otherwise, remove them to avoid confusion.
267-269: Rotation uses 180.0 by default, diverging from internal default of 15.0.

This large rotation range can significantly alter training behavior. If this is intentional, confirm it aligns with your training objectives. Otherwise, consider a smaller angle or the class default.
…ers, early_stopping, and commented out unspecified parameters
Actionable comments posted: 1
🧹 Nitpick comments (6)
sleap_nn/config/trainer_config.py (6)
236-236: Add a docstring to the trainer_mapper function.

The function is missing documentation explaining its purpose, parameters, and return value. A docstring would improve maintainability and help other developers understand the function's role.

 def trainer_mapper(legacy_config: dict) -> TrainerConfig:
+    """Convert a legacy configuration dictionary to a TrainerConfig instance.
+
+    This function extracts values from a legacy SLEAP configuration dictionary
+    and maps them to the appropriate fields in the TrainerConfig structure.
+
+    Args:
+        legacy_config: A dictionary containing legacy configuration parameters.
+
+    Returns:
+        A TrainerConfig instance populated with values from the legacy config.
+    """
     return TrainerConfig(
243-246: Decide on commented-out code.

There are numerous commented-out sections throughout the function. These make the code harder to read and maintain. Either implement these sections if they're needed or remove them if they're not.
If these configurations are planned for future implementation, consider adding a TODO comment indicating this, or remove them entirely if they're not needed.
311-350: Simplify conditional configuration creation.

The conditional creation of lr_scheduler is complex and nested. Consider extracting this logic to a separate helper function to improve readability.

+        lr_scheduler=_create_lr_scheduler_from_legacy(legacy_config),
-        lr_scheduler=(
-            LRSchedulerConfig(
-                # step_lr=StepLRConfig(
-                #     step_size=legacy_config.get("optimization", {})
-                #     .get("lr_scheduler", {})
-                #     .get("step_lr", {})
-                #     .get("step_size", 10),
-                #     gamma=legacy_config.get("optimization", {})
-                #     .get("lr_scheduler", {})
-                #     .get("step_lr", {})
-                #     .get("gamma", 0.1),
-                # ) if legacy_config.get("optimization", {}).get("lr_scheduler", {}).get("scheduler") == "StepLR" else None,
-                reduce_lr_on_plateau=ReduceLROnPlateauConfig(
-                    # threshold=legacy_config.get("optimization", {})
-                    # .get("lr_scheduler", {})
-                    # .get("reduce_lr_on_plateau", {})
-                    # .get("threshold", 1e-4),
-                    # threshold_mode=legacy_config.get("optimization", {})
-                    # .get("lr_scheduler", {})
-                    # .get("reduce_lr_on_plateau", {})
-                    # .get("threshold_mode", "rel"),
-                    # cooldown=legacy_config.get("optimization", {})
-                    # .get("lr_scheduler", {})
-                    # .get("reduce_lr_on_plateau", {})
-                    # .get("cooldown", 0),
-                    patience=legacy_config.get("optimization", {})
-                    .get("learning_rate_schedule", {})
-                    .get("plateau_patience", 10),
-                    # factor=legacy_config.get("optimization", {})
-                    # .get("lr_scheduler", {})
-                    # .get("reduce_lr_on_plateau", {})
-                    # .get("factor", 0.1),
-                    min_lr=legacy_config.get("optimization", {})
-                    .get("learning_rate_schedule", {})
-                    .get("min_learning_rate", 0.0),
-                )
-            )
-            if legacy_config.get("optimization", {}).get("learning_rate_schedule")
-            else None
-        ),

And add this helper function above:

def _create_lr_scheduler_from_legacy(legacy_config: dict) -> Optional[LRSchedulerConfig]:
    """Create LRSchedulerConfig from legacy configuration.

    Args:
        legacy_config: A dictionary containing legacy configuration parameters.

    Returns:
        A LRSchedulerConfig instance or None if learning_rate_schedule is not specified.
    """
    optimization = legacy_config.get("optimization", {})
    learning_rate_schedule = optimization.get("learning_rate_schedule")
    if not learning_rate_schedule:
        return None
    return LRSchedulerConfig(
        reduce_lr_on_plateau=ReduceLROnPlateauConfig(
            patience=learning_rate_schedule.get("plateau_patience", 10),
            min_lr=learning_rate_schedule.get("min_learning_rate", 0.0),
        )
    )
351-365: Similar to the lr_scheduler, extract early_stopping config creation to a helper function.

The conditional creation of early_stopping follows the same pattern as lr_scheduler and could benefit from the same refactoring approach.

+        early_stopping=_create_early_stopping_from_legacy(legacy_config),
-        early_stopping=(
-            EarlyStoppingConfig(
-                stop_training_on_plateau=legacy_config.get("optimization", {})
-                .get("learning_rate_schedule", {})
-                .get("reduce_on_plateau", False),
-                min_delta=legacy_config.get("optimization", {})
-                .get("learning_rate_schedule", {})
-                .get("plateau_min_delta", 0.0),
-                patience=legacy_config.get("optimization", {})
-                .get("learning_rate_schedule", {})
-                .get("plateau_patience", 1),
-            )
-            if legacy_config.get("optimization", {}).get("learning_rate_schedule")
-            else None
-        ),

And add this helper function below the _create_lr_scheduler_from_legacy function:

def _create_early_stopping_from_legacy(legacy_config: dict) -> Optional[EarlyStoppingConfig]:
    """Create EarlyStoppingConfig from legacy configuration.

    Args:
        legacy_config: A dictionary containing legacy configuration parameters.

    Returns:
        An EarlyStoppingConfig instance or None if learning_rate_schedule is not specified.
    """
    optimization = legacy_config.get("optimization", {})
    learning_rate_schedule = optimization.get("learning_rate_schedule")
    if not learning_rate_schedule:
        return None
    return EarlyStoppingConfig(
        stop_training_on_plateau=learning_rate_schedule.get("reduce_on_plateau", False),
        min_delta=learning_rate_schedule.get("plateau_min_delta", 0.0),
        patience=learning_rate_schedule.get("plateau_patience", 1),
    )
236-366: Add input validation for the legacy_config parameter.

The function doesn't validate the input legacy_config parameter. Consider adding validation to ensure it's a dictionary and contains the expected structure. Add validation at the beginning of the function:

 def trainer_mapper(legacy_config: dict) -> TrainerConfig:
+    """Convert a legacy configuration dictionary to a TrainerConfig instance.
+
+    Args:
+        legacy_config: A dictionary containing legacy configuration parameters.
+
+    Returns:
+        A TrainerConfig instance populated with values from the legacy config.
+
+    Raises:
+        TypeError: If legacy_config is not a dictionary.
+    """
+    if not isinstance(legacy_config, dict):
+        raise TypeError("legacy_config must be a dictionary")
+
     return TrainerConfig(
236-366: Consider using dictionary access helpers to simplify deeply nested dictionary accesses.

The function contains many instances of deeply nested dictionary access using chained .get() calls, which makes the code hard to read. Consider adding a helper function to simplify these accesses. Add a helper function at the beginning of the file:

def get_nested_dict_value(dictionary: dict, keys: list, default=None):
    """Safely get a value from a nested dictionary.

    Args:
        dictionary: The dictionary to search in.
        keys: A list of keys representing the path to the value.
        default: The default value to return if the key path doesn't exist.

    Returns:
        The value at the key path, or the default value if not found.
    """
    current = dictionary
    for key in keys:
        if not isinstance(current, dict):
            return default
        current = current.get(key, {})
    return current or default

Then use it in the function:

-        batch_size=legacy_config.get("optimization", {}).get("batch_size", 1),
+        batch_size=get_nested_dict_value(legacy_config, ["optimization", "batch_size"], 1),
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- sleap_nn/config/trainer_config.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #162      +/-   ##
==========================================
- Coverage   97.01%   91.57%    -5.45%
==========================================
  Files          46       46
  Lines        4961     4983      +22
==========================================
- Hits         4813     4563     -250
- Misses        148      420     +272

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/config/test_trainer_config.py (1)
239-297: Consider verifying the val_data_loader configuration as well.

The new test comprehensively checks most mapped parameters, but it might be beneficial to also verify that val_data_loader is assigned the expected defaults (e.g., batch_size, shuffle, etc.) since it mirrors the training data loader settings. This can increase coverage and potentially catch future regressions.

sleap_nn/config/trainer_config.py (1)
: Remove or implement the commented-out fields for clarity.Most fields are commented out (e.g.,
num_workers
,amsgrad
,step_lr
) and can cause confusion about whether they are fully supported. Consider either removing them if they are no longer needed or fully implementing them. This will help maintain code clarity and consistency with the rest of the configuration.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- sleap_nn/config/trainer_config.py (1 hunks)
- tests/config/test_trainer_config.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (1)
tests/config/test_trainer_config.py (1)
23-23: Looks good adding trainer_mapper to the import list.

No concerns here.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/config/test_trainer_config.py (1)
239-297: Consider expanding test coverage for other checkpoint settings.

This test function looks comprehensive, verifying both mapped values and defaults. As an enhancement, consider adding an additional test scenario where "checkpointing" contains "latest_model": True to confirm config.save_ckpt is correctly set to True.

+    legacy_config["optimization"]["checkpointing"] = {"latest_model": True}
+    config = trainer_mapper(legacy_config)
+    assert config.save_ckpt is True

tests/config/test_model_config.py (1)
.+ legacy_config["optimization"]["checkpointing"] = {"latest_model": True} + config = trainer_mapper(legacy_config) + assert config.save_ckpt is Truetests/config/test_model_config.py (1)
123-171
: Add optional scenarios for other heads.
The current test thoroughly verifies thesingle_instance
head when usingunet
. A more exhaustive approach might also test other heads or backbone types if legacy configurations include them (e.g.,centroid
,centered_instance
,multi_instance
).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
sleap_nn/config/model_config.py
(1 hunks)tests/config/test_model_config.py
(2 hunks)tests/config/test_trainer_config.py
(2 hunks)
🧰 Additional context used
🧬 Code Definitions (2)
tests/config/test_trainer_config.py (2)
sleap_nn/config/trainer_config.py (1)
trainer_mapper
(236-366)tests/fixtures/datasets.py (1)
config
(46-174)
tests/config/test_model_config.py (1)
sleap_nn/config/model_config.py (1)
model_mapper
(837-1002)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (4)
tests/config/test_trainer_config.py (1)
23-23
: Import appears consistent.
Addingtrainer_mapper
import looks straightforward and does not introduce issues.tests/config/test_model_config.py (1)
121-122
: No functional changes detected.
These lines appear to be empty or whitespace. No action needed.sleap_nn/config/model_config.py (2)
835-836
: No functional changes detected.
These lines appear to be blank lines or whitespace. No action needed.
837-1003: centered_instance incorrectly references CentroidConfMapsConfig.

This code uses CenteredInstanceConfig but supplies CentroidConfMapsConfig for confmaps, which is likely a mismatch. Instead, use CenteredInstanceConfMapsConfig to align with the centered_instance design.

-            confmaps=CentroidConfMapsConfig(
+            confmaps=CenteredInstanceConfMapsConfig(
                 anchor_part=legacy_config.get("model", {})
                 .get("heads", {})
                 .get("centered_instance", {})
                 .get("anchor_part"),
                 sigma=legacy_config.get("model", {})
                 .get("heads", {})
                 .get("centered_instance", {})
                 .get("sigma", 5.0),
                 output_stride=legacy_config.get("model", {})
                 .get("heads", {})
                 .get("centered_instance", {})
                 .get("output_stride", 1),
                 part_names=legacy_config.get("model", {})
                 .get("heads", {})
                 .get("centered_instance", {})
                 .get("part_names", None),
             )
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
sleap_nn/config/data_config.py
(1 hunks)tests/config/test_data_config.py
(2 hunks)
🧰 Additional context used
🧬 Code Definitions (1)
tests/config/test_data_config.py (1)
sleap_nn/config/data_config.py (1)
data_mapper
(190-295)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (2)
tests/config/test_data_config.py (2)
20-20
: Import looks good.The addition of
data_mapper
to the import statement correctly makes the function available for testing.
127-196
: Well-structured and comprehensive test function.The test function thoroughly validates the
data_mapper
function by:
- Creating a realistic legacy config
- Calling the mapper function
- Checking all relevant properties of the returned configuration
The test cases cover preprocessing, augmentation, intensity, geometric, and skeleton configurations.
Actionable comments posted: 0
♻️ Duplicate comments (2)
sleap_nn/config/data_config.py (2)
277-284
: 🛠️ Refactor suggestionFix scale parameter to use valid values instead of None.
The
scale
parameter is set with potentiallyNone
values, which could cause validation errors since theGeometricConfig
class expects a list of floats with default(0.9, 1.1)
.scale=( legacy_config.get("optimization", {}) .get("augmentation_config", {}) - .get("scale_min", None), + .get("scale_min", 0.9), legacy_config.get("optimization", {}) .get("augmentation_config", {}) - .get("scale_max", None), + .get("scale_max", 1.1), ),
297-300
: 🛠️ Refactor suggestionFix inconsistent handling of use_augmentations_train.
The function:
- Has a commented-out conditional check for
augmentation_config
(lines 297-298)- Hardcodes
use_augmentations_train=True
(line 300)This creates inconsistency with the
DataConfig
class default (False
) and makes the augmentation configuration ignore the actual setting.) - # if legacy_config.get("use_augmentations_train", False) - # else None + if legacy_config.get("optimization", {}).get("use_augmentations", True) + else None ), - use_augmentations_train=True, + use_augmentations_train=legacy_config.get("optimization", {}).get("use_augmentations", True),
🧹 Nitpick comments (10)
tests/assets/fixtures/datasets.py (1)
4-4
: Remove the unused import to adhere to cleanliness.Line 4 imports
OmegaConf
but it's never used in this file. Please remove it to keep the file lean and avoid unnecessary dependencies.- from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff (0.8.2)
4-4:
omegaconf.OmegaConf
imported but unusedRemove unused import:
omegaconf.OmegaConf
(F401)
tests/config/test_training_job_config.py (3)
36-36
: Remove the unused import forsleapnn_data_dir
if it's not needed.Static analysis indicates
sleapnn_data_dir
is unused. Review whether this fixture is needed; if not, remove it to eliminate the lint warning and keep dependencies minimal.- from tests.assets.fixtures.datasets import sleapnn_data_dir, training_job_config_path + from tests.assets.fixtures.datasets import training_job_config_path🧰 Tools
🪛 Ruff (0.8.2)
36-36:
tests.assets.fixtures.datasets.sleapnn_data_dir
imported but unusedRemove unused import
(F401)
36-36:
tests.assets.fixtures.datasets.training_job_config_path
imported but unusedRemove unused import
(F401)
38-38
: Remove the unused importdataclasses.asdict
.This import does not appear to be used in the test file. Removing it helps tidy up the code.
- from dataclasses import asdict
🧰 Tools
🪛 Ruff (0.8.2)
38-38:
dataclasses.asdict
imported but unusedRemove unused import:
dataclasses.asdict
(F401)
41-41
: Remove the unused importjson
.Although
load_sleap_config
relies onjson
, this test file itself does not calljson
directly. Removing the local import is recommended to satisfy static analysis warnings.- import json
🧰 Tools
🪛 Ruff (0.8.2)
41-41:
json
imported but unusedRemove unused import:
json
(F401)
sleap_nn/config/model_config.py (1)
837-1003
: Consider broader support for non-UNet backbones.Currently, this mapper only populates a
UNetConfig
and skips other potential backbone structures. If your legacy configs might includeconvnext
orswint
entries, consider handling them similarly. This ensures graceful fallback rather than silently ignoring those entries.sleap_nn/config/trainer_config.py (3)
237-371
: Add a docstring to thetrainer_mapper
function.Adding a docstring would improve the maintainability and readability of this function by documenting its purpose, parameters, and return value.
+""" +Maps a legacy configuration dictionary to a TrainerConfig object. + +Parameters: + legacy_config (dict): A dictionary containing legacy configuration settings. + +Returns: + TrainerConfig: An instance of TrainerConfig populated with values from legacy_config. +""" def trainer_mapper(legacy_config: dict) -> TrainerConfig:
305-309
: Simplify the regex for capitalizing the first letter.The regex to capitalize the first letter of the optimizer name is more complex than needed. Using the built-in
capitalize()
method would be more readable.- optimizer_name=re.sub( - r"^[a-z]", - lambda x: x.group().upper(), - legacy_config.get("optimization", {}).get("optimizer", "adam"), - ), + optimizer_name=legacy_config.get("optimization", {}).get("optimizer", "adam").capitalize(),
242-247
: Explain the purpose of commented-out code sections.There are multiple sections of commented-out code without explanation. This makes it difficult to understand whether these sections are incomplete implementations, deprecated code, or placeholders for future development.
Add explanatory comments at the beginning of commented-out sections, for example:
+ # NOTE: The following parameters are intentionally not mapped as they are not + # present in legacy SLEAP configurations or have different default values # num_workers=legacy_config.get("optimization", {}) # .get("train_data_loader", {}) # .get("num_workers", 0),Also applies to: 250-255, 258-260, 262-268, 271-281, 282-304
sleap_nn/config/data_config.py (2)
190-191
: Add a docstring to thedata_mapper
function.The function lacks documentation explaining its purpose, parameters, and return value, which would improve maintainability.
+""" +Maps a legacy configuration dictionary to a DataConfig object. + +Parameters: + legacy_config (dict): A dictionary containing legacy configuration settings. + +Returns: + DataConfig: An instance of DataConfig populated with values from legacy_config. +""" def data_mapper(legacy_config: dict) -> DataConfig:
236-239
: Document the normalization of uniform_noise_max_val.The function divides
uniform_noise_max_val
by 100.0, but this normalization is not documented and isn't applied to other similar values.uniform_noise_max=legacy_config.get("optimization", {}) .get("augmentation_config", {}) - .get("uniform_noise_max_val", 100.0) - / 100.0, + .get("uniform_noise_max_val", 100.0) / 100.0, # Convert percentage to fraction
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
sleap_nn/config/data_config.py
(1 hunks)sleap_nn/config/model_config.py
(1 hunks)sleap_nn/config/trainer_config.py
(2 hunks)tests/assets/fixtures/datasets.py
(1 hunks)tests/assets/training_config.json
(1 hunks)tests/config/test_data_config.py
(2 hunks)tests/config/test_training_job_config.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/config/test_data_config.py
🧰 Additional context used
🧬 Code Definitions (1)
tests/config/test_training_job_config.py (4)
sleap_nn/config/training_job_config.py (1)
load_sleap_config
(155-167)sleap_nn/config/model_config.py (1)
ModelConfig
(755-834)sleap_nn/config/trainer_config.py (1)
TrainerConfig
(163-234)tests/assets/fixtures/datasets.py (2)
sleapnn_data_dir
(10-12)training_job_config_path
(16-18)
🪛 Ruff (0.8.2)
tests/assets/fixtures/datasets.py
4-4: omegaconf.OmegaConf
imported but unused
Remove unused import: omegaconf.OmegaConf
(F401)
tests/config/test_training_job_config.py
36-36: tests.assets.fixtures.datasets.sleapnn_data_dir
imported but unused
Remove unused import
(F401)
36-36: tests.assets.fixtures.datasets.training_job_config_path
imported but unused
Remove unused import
(F401)
38-38: dataclasses.asdict
imported but unused
Remove unused import: dataclasses.asdict
(F401)
41-41: json
imported but unused
Remove unused import: json
(F401)
235-235: Redefinition of unused training_job_config_path
from line 36
(F811)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (4)
tests/config/test_training_job_config.py (2)
235-235
: Pytest fixture overshadowing warning.Static analysis flags a redefinition of
training_job_config_path
. In pytest, naming the parameter the same as the fixture is intentional so that the fixture is injected automatically. This can safely be ignored unless you prefer a different naming approach to appease the linter.🧰 Tools
🪛 Ruff (0.8.2)
235-235: Redefinition of unused
training_job_config_path
from line 36(F811)
235-261
: Test coverage looks comprehensive.Verifying fields against the JSON config ensures correct loading. These assertions confirm each key aspect of the loaded config, which improves confidence in backward compatibility.
🧰 Tools
🪛 Ruff (0.8.2)
235-235: Redefinition of unused
training_job_config_path
from line 36(F811)
tests/assets/training_config.json (1)
1-218
: Configuration file creation approved.This JSON file thoroughly covers model, data, and optimization parameters for your legacy config. Adding it strengthens test coverage for the new
load_sleap_config
functionality.sleap_nn/config/model_config.py (1)
933-953
:centered_instance
head incorrectly usesCentroidConfMapsConfig
.As previously noted,
CenteredInstanceConfig
expects aCenteredInstanceConfMapsConfig
, not aCentroidConfMapsConfig
. This mismatch can lead to erroneous behavior.- CenteredInstanceConfig( - confmaps=CentroidConfMapsConfig( + CenteredInstanceConfig( + confmaps=CenteredInstanceConfMapsConfig( anchor_part=..., sigma=..., output_stride=..., part_names=... ) )
Actionable comments posted: 0
🔭 Outside diff range comments (1)
tests/config/test_training_job_config.py (1)
236-279
: 🛠️ Refactor suggestionImprove exception handling and file management in test.
The test has a few issues that should be addressed:
- Unused variable
e
in the exception handler- Unsafe temporary file handling that doesn't guarantee cleanup
- Complex control flow with try-except block
Consider using a context manager for temporary file handling and simplifying the test:
def test_load_sleap_config_from_file(training_job_config_path): """Test the load_sleap_config function with a sample legacy configuration from a JSON file.""" # Path to the training_config.json file json_file_path = training_job_config_path # Load the configuration using the load_sleap_config method try: # Load the configuration using the load_sleap_config method config = load_sleap_config(TrainingJobConfig, json_file_path) - except MissingMandatoryValue as e: + except MissingMandatoryValue: with open(json_file_path, "r") as f: old_config = json.load(f) # Create a temporary file to hold the modified configuration - with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as temp_file: + with tempfile.NamedTemporaryFile(suffix='.json', mode='w', delete=False) as temp_file: old_config['data']['labels']['training_labels'] = "notMISSING" old_config['data']['labels']['validation_labels'] = "notMISSING" json.dump(old_config, temp_file) temp_file_path = temp_file.name - config = load_sleap_config(TrainingJobConfig, temp_file_path) - os.remove(temp_file_path) + try: + config = load_sleap_config(TrainingJobConfig, temp_file_path) + finally: + os.remove(temp_file_path)Also consider validating the exception message to ensure the correct failure is caught:
def test_load_sleap_config_missing_fields(training_job_config_path): """Test that load_sleap_config raises appropriate exceptions for missing fields.""" with tempfile.NamedTemporaryFile(suffix='.json', mode='w') as temp_file: # Create a config with missing mandatory fields json.dump({"data": {"labels": {}}}, temp_file) temp_file.flush() # Verify the exception is raised with the expected message with pytest.raises(MissingMandatoryValue, match="train_labels_path is missing"): load_sleap_config(TrainingJobConfig, temp_file.name)🧰 Tools
🪛 Ruff (0.8.2)
236-236: Redefinition of unused
training_job_config_path
from line 36(F811)
245-245: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
🧹 Nitpick comments (4)
tests/config/test_data_config.py (1)
127-200
: Good test coverage fordata_mapper
function!The test provides comprehensive validation for all aspects of the
data_mapper
function, including preprocessing, augmentation, and skeleton configurations.Consider adding tests for error conditions:
- What happens when mandatory fields like
training_labels
orvalidation_labels
are missing?- How does the function handle edge cases or invalid values?
def test_data_mapper_missing_mandatory_fields(): """Test the data_mapper function with missing mandatory fields.""" # Test missing training_labels legacy_config = { "data": { "labels": { "validation_labels": "valid_path", } } } with pytest.raises(MissingMandatoryValue, match="train_labels_path is missing"): data_mapper(legacy_config) # Test missing validation_labels legacy_config = { "data": { "labels": { "training_labels": "valid_path", } } } with pytest.raises(MissingMandatoryValue, match="val_labels_path is missing"): data_mapper(legacy_config)tests/config/test_training_job_config.py (3)
31-36
: Clean up unused imports.Several imports are not being used in this file. Consider removing them to improve code clarity.
from sleap_nn.config.training_job_config import TrainingJobConfig from sleap_nn.config.training_job_config import load_sleap_config from sleap_nn.config.model_config import ModelConfig from sleap_nn.config.data_config import DataConfig from sleap_nn.config.trainer_config import TrainerConfig, EarlyStoppingConfig from sleap_nn.config.data_config import IntensityConfig -from tests.assets.fixtures.datasets import sleapnn_data_dir, training_job_config_path +from tests.assets.fixtures.datasets import training_job_config_path from omegaconf import OmegaConf, MissingMandatoryValue -from dataclasses import asdict import json -from omegaconf import MISSING🧰 Tools
🪛 Ruff (0.8.2)
36-36:
tests.assets.fixtures.datasets.sleapnn_data_dir
imported but unusedRemove unused import
(F401)
36-36:
tests.assets.fixtures.datasets.training_job_config_path
imported but unusedRemove unused import
(F401)
245-245
: Remove unused exception variable.The exception variable
e
is not being used in the exception handler.try: # Load the configuration using the load_sleap_config method config = load_sleap_config(TrainingJobConfig, json_file_path) - except MissingMandatoryValue as e: + except MissingMandatoryValue:🧰 Tools
🪛 Ruff (0.8.2)
245-245: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
251-260
: Use context manager for safer temporary file handling.The current approach doesn't guarantee the temporary file will be deleted if an exception occurs during loading.
- with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as temp_file: - old_config['data']['labels']['training_labels'] = "notMISSING" - old_config['data']['labels']['validation_labels'] = "notMISSING" - - json.dump(old_config, temp_file) - temp_file_path = temp_file.name - - config = load_sleap_config(TrainingJobConfig, temp_file_path) - os.remove(temp_file_path) + with tempfile.NamedTemporaryFile(suffix='.json', mode='w', delete=False) as temp_file: + old_config['data']['labels']['training_labels'] = "notMISSING" + old_config['data']['labels']['validation_labels'] = "notMISSING" + + json.dump(old_config, temp_file) + temp_file_path = temp_file.name + + try: + config = load_sleap_config(TrainingJobConfig, temp_file_path) + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
sleap_nn/config/data_config.py
(2 hunks)tests/config/test_data_config.py
(2 hunks)tests/config/test_training_job_config.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- sleap_nn/config/data_config.py
🧰 Additional context used
🧬 Code Definitions (2)
tests/config/test_data_config.py (1)
sleap_nn/config/data_config.py (1)
data_mapper
(190-305)
tests/config/test_training_job_config.py (2)
sleap_nn/config/training_job_config.py (2)
load_sleap_config
(155-167)TrainingJobConfig
(42-138)tests/assets/fixtures/datasets.py (2)
sleapnn_data_dir
(10-12)training_job_config_path
(16-18)
🪛 Ruff (0.8.2)
tests/config/test_training_job_config.py
36-36: tests.assets.fixtures.datasets.sleapnn_data_dir
imported but unused
Remove unused import
(F401)
36-36: tests.assets.fixtures.datasets.training_job_config_path
imported but unused
Remove unused import
(F401)
38-38: dataclasses.asdict
imported but unused
Remove unused import: dataclasses.asdict
(F401)
42-42: omegaconf.MISSING
imported but unused
Remove unused import: omegaconf.MISSING
(F401)
236-236: Redefinition of unused training_job_config_path
from line 36
(F811)
245-245: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
Actionable comments posted: 1
🧹 Nitpick comments (12)
sleap_nn/config/model_config.py (1)
835-836
: Remove unnecessary blank lines.These two blank lines appear to serve no functional purpose. Consider removing them to keep the code more concise and aligned with standard style conventions.
tests/config/test_data_config.py (1)
127-200
: Test provides comprehensive coverage of the data_mapper function.The test effectively validates all aspects of the
data_mapper
function fromsleap_nn/config/data_config.py
, including:
- Correct handling of labels paths
- Preprocessing configuration conversion
- Augmentation configuration handling
- Intensity configuration parameters
- Geometric transformation parameters
- Skeletons configuration mapping
Consider adding test cases for error conditions, such as missing mandatory values (e.g., when
train_labels_path
orval_labels_path
are missing). This would help ensure the function correctly raises appropriate exceptions in those scenarios.def test_data_mapper_missing_mandatory_values(): """Test that data_mapper raises appropriate exceptions for missing mandatory values.""" # Test missing train_labels_path legacy_config_missing_train = { "data": { "labels": { "validation_labels": "some_path.slp" } } } with pytest.raises(MissingMandatoryValue, match="train_labels_path is missing"): data_mapper(legacy_config_missing_train) # Test missing val_labels_path legacy_config_missing_val = { "data": { "labels": { "training_labels": "some_path.slp" } } } with pytest.raises(MissingMandatoryValue, match="val_labels_path is missing"): data_mapper(legacy_config_missing_val)tests/assets/fixtures/datasets.py (2)
4-4
: Remove unused import.
The static analysis tool detects thatOmegaConf
is imported but unused.Apply this diff to remove the unused import:
-from omegaconf import OmegaConf
🧰 Tools
🪛 Ruff (0.8.2)
4-4:
omegaconf.OmegaConf
imported but unusedRemove unused import:
omegaconf.OmegaConf
(F401)
34-35
: Update docstring to match file name.
The docstring says"Path to centered_instance_training_config file."
but the actual file is"centered_instance_with_scaling_training_config.json"
.- """Path to centered_instance_training_config file.""" + """Path to centered_instance_with_scaling_training_config file."""tests/config/test_training_job_config.py (8)
36-36
: Avoid wildcard import.
Usingimport *
can mask undefined names.Here is a diff showing an explicit import:
-from tests.assets.fixtures.datasets import * +from tests.assets.fixtures.datasets import ( + sleapnn_data_dir, + training_job_config_path, + bottomup_training_config_path, + centered_instance_training_config_path, + centered_instance_with_scaling_training_config_path, + centroid_training_config_path, + single_instance_training_config_path, + topdown_training_config_path, +)🧰 Tools
🪛 Ruff (0.8.2)
36-36:
from tests.assets.fixtures.datasets import *
used; unable to detect undefined names(F403)
38-38
: Remove unused import.
dataclasses.asdict
is not referenced in this file.-from dataclasses import asdict
🧰 Tools
🪛 Ruff (0.8.2)
38-38:
dataclasses.asdict
imported but unusedRemove unused import:
dataclasses.asdict
(F401)
42-42
: Remove unused import.
omegaconf.MISSING
is never used.-from omegaconf import MISSING
🧰 Tools
🪛 Ruff (0.8.2)
42-42:
omegaconf.MISSING
imported but unusedRemove unused import:
omegaconf.MISSING
(F401)
43-43
: Remove unused import.
pprint.pprint
is never used.-from pprint import pprint
🧰 Tools
🪛 Ruff (0.8.2)
43-43:
pprint.pprint
imported but unusedRemove unused import:
pprint.pprint
(F401)
245-245
: Remove assignment to unused variable.
The local variablee
is assigned but never used.- except MissingMandatoryValue as e: + except MissingMandatoryValue:🧰 Tools
🪛 Ruff (0.8.2)
245-245: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
296-296
: Remove assignment to unused variable.
The local variablee
is assigned but never used.- except MissingMandatoryValue as e: + except MissingMandatoryValue:🧰 Tools
🪛 Ruff (0.8.2)
296-296: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
340-340
: Remove assignment to unused variable.
Again,e
is not utilized.- except MissingMandatoryValue as e: + except MissingMandatoryValue:🧰 Tools
🪛 Ruff (0.8.2)
340-340: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
237-285
: Refactor repeated code blocks in tests.
Each test repeats the same logic for handlingMissingMandatoryValue
, rewriting JSON, and loading the updated file.Consider extracting this into a helper function or fixture to reduce duplication:
def load_or_fix_legacy_config(json_file_path, required_keys=None): """ Attempt to load the config. If MissingMandatoryValue is raised, rewrite the config JSON with non-missing placeholders. """ from omegaconf import MissingMandatoryValue import tempfile, os, json if required_keys is None: required_keys = [("data", "labels", "training_labels"), ("data", "labels", "validation_labels")] try: return load_sleap_config(TrainingJobConfig, json_file_path) except MissingMandatoryValue: with open(json_file_path, "r") as f: old_config = json.load(f) for key_path in required_keys: cfg = old_config for kp in key_path[:-1]: cfg = cfg[kp] cfg[key_path[-1]] = "notMISSING" with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as temp_file: json.dump(old_config, temp_file) temp_file_path = temp_file.name config = load_sleap_config(TrainingJobConfig, temp_file_path) os.remove(temp_file_path) return configAlso applies to: 287-327, 329-371, 373-407, 409-439, 441-473, 475-510
🧰 Tools
🪛 Ruff (0.8.2)
245-245: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
sleap_nn/config/data_config.py
(2 hunks)sleap_nn/config/model_config.py
(1 hunks)tests/assets/bottomup_training_config.json
(1 hunks)tests/assets/centered_instance_training_config.json
(1 hunks)tests/assets/centered_instance_with_scaling_training_config.json
(1 hunks)tests/assets/centroid_training_config.json
(1 hunks)tests/assets/fixtures/datasets.py
(1 hunks)tests/assets/single_instance_training_config.json
(1 hunks)tests/assets/topdown_training_config.json
(1 hunks)tests/config/test_data_config.py
(2 hunks)tests/config/test_training_job_config.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- sleap_nn/config/data_config.py
🧰 Additional context used
🧬 Code Definitions (2)
tests/config/test_data_config.py (1)
sleap_nn/config/data_config.py (1)
data_mapper
(190-317)
tests/config/test_training_job_config.py (2)
sleap_nn/config/training_job_config.py (2)
load_sleap_config
(155-167)TrainingJobConfig
(42-138)tests/assets/fixtures/datasets.py (7)
training_job_config_path
(16-18)bottomup_training_config_path
(22-24)centered_instance_training_config_path
(28-30)centered_instance_with_scaling_training_config_path
(34-38)centroid_training_config_path
(42-44)single_instance_training_config_path
(48-50)topdown_training_config_path
(54-56)
🪛 Ruff (0.8.2)
tests/assets/fixtures/datasets.py
4-4: omegaconf.OmegaConf
imported but unused
Remove unused import: omegaconf.OmegaConf
(F401)
tests/config/test_training_job_config.py
36-36: from tests.assets.fixtures.datasets import *
used; unable to detect undefined names
(F403)
38-38: dataclasses.asdict
imported but unused
Remove unused import: dataclasses.asdict
(F401)
42-42: omegaconf.MISSING
imported but unused
Remove unused import: omegaconf.MISSING
(F401)
43-43: pprint.pprint
imported but unused
Remove unused import: pprint.pprint
(F401)
245-245: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
296-296: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
340-340: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
🔇 Additional comments (9)
sleap_nn/config/model_config.py (1)
837-1010
: Ensure consistency in mapped fields and defaults, and confirm handling of unimplemented parameters.
- The function comments indicate support for
init_weights
,pre_trained_weights
, andpretrained_backbone_weights
, but these parameters are commented out. Confirm that omitting them from the mapped values is intended or, if relevant, implement their extraction fromlegacy_config
.stem_stride
forUNetConfig
is given a default of16
here, whereas the class definition usesNone
. This may cause discrepancies if the legacy config expects the fallback to match the class default.convs_per_block
in the UNet config is also commented out. Verify whether you want to keep the class default (2
) or map a value fromlegacy_config
.- Confirm that
"multi_instance"
in the old config is indeed the correct indicator for thebottomup
head logic. If old configs had an explicit"bottomup"
key, consider adjusting accordingly.- The code does not handle fields like
offset_refinement
orloss_weight
forcentered_instance
or other heads (unless explicitly mapped). Ensure this omission is intentional.tests/assets/centered_instance_with_scaling_training_config.json (1)
1-226
: Verify newly added parameters align with code usage.
- The
"centered_instance"
block includes"offset_refinement"
and"loss_weight"
fields, but the currentCenteredInstanceConfig
class does not parse them. Verify whether these fields should be read and mapped to the code.- Setting
"center_on_part": null
might be intentional, but confirm the downstream logic handlesNone
gracefully, especially if you later rely on an anchor part.- Windows-style double backslashes in
"filename"
(line 225) are acceptable if needed for a Windows environment, but ensure any code consuming this path can handle it properly.tests/assets/centroid_training_config.json (1)
1-202
: Confirm “offset_refinement” usage for centroid head.In the
"centroid"
head configuration,"offset_refinement"
is set totrue
, but neitherCentroidConfig
norCentroidConfMapsConfig
includes a corresponding field. Ensure that any additional fields beyond"anchor_part"
,"sigma"
, and"output_stride"
are either needed or supported at runtime.tests/config/test_data_config.py (1)
20-20
: LGTM: Added data_mapper import to test module.The import of the new
data_mapper
function is correctly added to the existing import statement from the data_config module.tests/assets/bottomup_training_config.json (1)
1-219
:✅ Verification successful
Comprehensive bottomup training configuration template.
The configuration file is well-structured and covers all necessary parameters for bottomup training, including data preprocessing, model architecture, optimization settings, and output handling.
The training and validation labels are set to
null
, which may be intentional for a template file. Verify that tests or example code that uses this configuration handle these null values appropriately by either:
🏁 Script executed:
#!/bin/bash # Check if there are tests that verify null label paths are handled properly rg -A 10 "test.*bottomup.*config" --type py tests/Length of output: 4156
Review Update: Bottomup Training Config Verified
The bottomup training configuration in
tests/assets/bottomup_training_config.json
is structured well. The tests (in files such astests/training/test_model_trainer.py
,tests/test_train.py
,tests/data/test_streaming_datasets.py
,tests/data/test_get_data_chunks.py
, andtests/config/test_training_job_config.py
) confirm that the intentionalnull
values fortraining_labels
,validation_labels
, andtest_labels
are handled appropriately within the application flow.No further changes are required.
tests/assets/centered_instance_training_config.json (1)
1-206
:✅ Verification successful
Complete centered instance training configuration template.
The configuration provides all necessary parameters for centered-instance model training with appropriate structure and organization.
Similar to the bottomup configuration, the training and validation labels are set to
null
. Verify that code using this configuration handles these null values correctly:
🏁 Script executed:
#!/bin/bash # Check if there are tests for centered instance config with null labels rg -A 10 "test.*centered.*instance.*config" --type py tests/ # Check for potential validation of required fields rg -A 5 -B 5 "validation.*labels.*null" --type py sleap_nn/Length of output: 4424
Centered Instance Training Config Verified and Approved
The configuration template is complete and organized, mirroring the bottom‐up configuration by setting the training, validation, and test labels to
null
. The tests—such as those intests/config/test_training_job_config.py
,tests/data/test_centered_instance_data_chunks.py
, andtests/data/test_streaming_datasets.py
—successfully load and utilize this configuration, indicating that the code properly handles thesenull
values. No additional changes are required.tests/assets/fixtures/datasets.py (1)
1-2
: Docstring looks good.
Coverage of high-level purpose of dataset fixtures is clear and concise.tests/assets/topdown_training_config.json (1)
1-240
: Configuration file appears logically structured.
No syntactic or semantic issues are apparent. The sections (data, model, optimization, outputs) illustrate a comprehensive top-down approach. The skeleton nodes and edges are correctly nested.tests/config/test_training_job_config.py (1)
31-31
: New import ofload_sleap_config
is appropriate.
This import corresponds directly to the new tests, enabling JSON-based loading functionality.
"centroid": null, | ||
"centered_instance": null, | ||
"multi_instance": null | ||
} | ||
}, | ||
"optimization": { | ||
"preload_data": true, | ||
"augmentation_config": { | ||
"rotate": false, | ||
"rotation_min_angle": -180.0, | ||
"rotation_max_angle": 180.0, | ||
"translate": false, | ||
"translate_min": -5, | ||
"translate_max": 5, | ||
"scale": false, | ||
"scale_min": 0.9, | ||
"scale_max": 1.1, | ||
"uniform_noise": false, | ||
"uniform_noise_min_val": 0.0, | ||
"uniform_noise_max_val": 10.0, | ||
"gaussian_noise": false, | ||
"gaussian_noise_mean": 5.0, | ||
"gaussian_noise_stddev": 1.0, | ||
"contrast": false, | ||
"contrast_min_gamma": 0.5, | ||
"contrast_max_gamma": 2.0, | ||
"brightness": false, | ||
"brightness_min_val": 0.0, | ||
"brightness_max_val": 10.0, | ||
"random_crop": false, | ||
"random_crop_height": 256, | ||
"random_crop_width": 256 | ||
}, | ||
"online_shuffling": true, | ||
"shuffle_buffer_size": 128, | ||
"prefetch": true, | ||
"batch_size": 4, | ||
"batches_per_epoch": 200, | ||
"min_batches_per_epoch": 200, | ||
"val_batches_per_epoch": 10, | ||
"min_val_batches_per_epoch": 10, | ||
"epochs": 100, | ||
"optimizer": "adam", | ||
"initial_learning_rate": 0.001, | ||
"learning_rate_schedule": { | ||
"reduce_on_plateau": true, | ||
"reduction_factor": 0.5, | ||
"plateau_min_delta": 1e-06, | ||
"plateau_patience": 5, | ||
"plateau_cooldown": 3, | ||
"min_learning_rate": 1e-08 | ||
}, | ||
"hard_keypoint_mining": { | ||
"online_mining": false, | ||
"hard_to_easy_ratio": 2.0, | ||
"min_hard_keypoints": 2, | ||
"max_hard_keypoints": null, | ||
"loss_scale": 5.0 | ||
}, | ||
"early_stopping": { | ||
"stop_training_on_plateau": true, | ||
"plateau_min_delta": 1e-06, | ||
"plateau_patience": 10 | ||
} | ||
}, | ||
"outputs": { | ||
"save_outputs": true, | ||
"run_name": "minimal_robot.UNet.single_instance", | ||
"run_name_prefix": "", | ||
"run_name_suffix": "", | ||
"runs_folder": "", | ||
"tags": [ | ||
"" | ||
], | ||
"save_visualizations": false, | ||
"keep_viz_images": true, | ||
"log_to_csv": true, | ||
"checkpointing": { | ||
"initial_model": false, | ||
"best_model": true, | ||
"every_epoch": false, | ||
"latest_model": false, | ||
"final_model": false | ||
}, | ||
"tensorboard": { | ||
"write_logs": false, | ||
"loss_frequency": "epoch", | ||
"architecture_graph": false, | ||
"profile_graph": false, | ||
"visualizations": true | ||
}, | ||
"zmq": { | ||
"subscribe_to_controller": true, | ||
"controller_address": "tcp://127.0.0.1:9000", | ||
"controller_polling_timeout": 10, | ||
"publish_updates": true, | ||
"publish_address": "tcp://127.0.0.1:9001" | ||
} | ||
} | ||
} |
💡 Verification agent
🧩 Analysis chain
Well-defined single instance training configuration.
This configuration file provides a complete setup for single-instance training with appropriate paths, model structure, and optimization parameters.
The training labels path ("tests/data/slp_hdf5/small_robot_minimal.slp"
) needs to exist for tests to pass. Let's verify this:
🏁 Script executed:
#!/bin/bash
# Check if the referenced training data file exists
if [ -f "tests/data/slp_hdf5/small_robot_minimal.slp" ]; then
echo "Training data file exists"
else
echo "ERROR: Training data file not found"
# List available files in that directory to help debugging
echo "Available files:"
ls -la tests/data/slp_hdf5/ 2>/dev/null || echo "Directory does not exist"
fi
Length of output: 269
Action Required: Missing Training Data File
The single instance training configuration is well-structured; however, the referenced training labels path ("tests/data/slp_hdf5/small_robot_minimal.slp"
) does not exist. This missing file (or directory) is causing tests to fail.
- Action: Please ensure that the file is added to the repository at the correct location, or update the configuration file to provide a valid path for the training labels.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
sleap_nn/config/data_config.py (1)
190-313
: 🛠️ Refactor suggestion

Several concerns in the new data_mapper function.

- Commented-out code: Key parameters like provider, user_instances_only, etc., remain commented out. If they are required, consider uncommenting or removing them altogether to reduce confusion.
- Forcing use_augmentations_train to True: You currently override this with a hard-coded True (line 310). Most designs expect reading it from the legacy config (e.g., legacy_config.get("use_augmentations_train", False)).
- Potential (None, None) usage: Passing (None, None) as scale (lines 288–293) can break validations in GeometricConfig. Confirm whether this is intended or provide a default numeric range if the config is missing a scale setting.

Below is a snippet to let legacy_config decide if augmentations are used, and to supply a safer default for scale:

```diff
-            use_augmentations_train=True,
+            use_augmentations_train=bool(
+                legacy_config.get("optimization", {})
+                .get("use_augmentations", False)
+            ),
-            scale=(
-                legacy_config.get("optimization", {})
-                .get("augmentation_config", {})
-                .get("scale_min", None),
-                legacy_config.get("optimization", {})
-                .get("augmentation_config", {})
-                .get("scale_max", None),
-            ),
+            scale=(
+                legacy_config.get("optimization", {})
+                .get("augmentation_config", {})
+                .get("scale_min", 0.9),
+                legacy_config.get("optimization", {})
+                .get("augmentation_config", {})
+                .get("scale_max", 1.1),
+            ),
```
♻️ Duplicate comments (1)
sleap_nn/config/training_job_config.py (1)
155-173
: 🛠️ Refactor suggestion

Multiple issues with the load_sleap_config function.

- Lack of error handling: No try/except to handle file I/O errors and JSON parse errors (e.g., FileNotFoundError, json.JSONDecodeError). This can lead to uncaught exceptions at runtime.
- @classmethod usage: The function uses cls as its first parameter, but it isn't decorated with @classmethod, causing potential confusion about its intended usage.
- Missing check_output_strides: Unlike from_yaml and load_yaml, this loader doesn't invoke TrainingJobConfig.check_output_strides(config), potentially leading to inconsistent stride settings relative to the YAML loading path.
- Return type mismatch: The function signature says it returns a TrainerConfig, but the actual return object is an OmegaConf instance.

Below is a suggested refactor that includes error handling, calls check_output_strides, and aligns the return type:

```diff
+@classmethod
-def load_sleap_config(cls, json_file_path: str) -> TrainerConfig:
+def load_sleap_config(cls, json_file_path: str) -> OmegaConf:
+    import json
     try:
         with open(json_file_path, "r") as f:
             old_config = json.load(f)
     except FileNotFoundError:
         raise FileNotFoundError(f"Configuration file not found: {json_file_path}")
     except json.JSONDecodeError:
         raise ValueError(f"Invalid JSON in configuration file: {json_file_path}")

     data_config = data_mapper(old_config)
     model_config = model_mapper(old_config)
     trainer_config = trainer_mapper(old_config)

     config = cls(
         data_config=data_config,
         model_config=model_config,
         trainer_config=trainer_config,
     )

     schema = OmegaConf.structured(config)
     config_omegaconf = OmegaConf.merge(schema, OmegaConf.create(asdict(config)))
+    config_omegaconf = cls.check_output_strides(config_omegaconf)
     OmegaConf.to_container(config_omegaconf, resolve=True, throw_on_missing=True)
     return config_omegaconf
```
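For orientation, a minimal usage sketch of the loader as it currently stands in this PR (the JSON path below is a placeholder, not a file shipped with the repo):

```python
from sleap_nn.config.training_job_config import TrainingJobConfig, load_sleap_config

# Load a legacy SLEAP training_config.json and inspect a few mapped fields.
config = load_sleap_config(TrainingJobConfig, "path/to/legacy_training_config.json")
print(config.data_config.train_labels_path)
print(config.trainer_config.optimizer_name)
```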
🧹 Nitpick comments (3)
sleap_nn/config/data_config.py (1)
8-8
: Remove the unused import.
MissingMandatoryValue
is never invoked, so please remove it to address the linter's warning and keep your imports clean.

```diff
-from omegaconf import MISSING, MissingMandatoryValue
+from omegaconf import MISSING
```

🧰 Tools
🪛 Ruff (0.8.2)
8-8:
omegaconf.MissingMandatoryValue
imported but unusedRemove unused import:
omegaconf.MissingMandatoryValue
(F401)
tests/config/test_training_job_config.py (2)
36-37
: Avoid star-import if possible.
Usingimport *
from fixtures can clutter the namespace and make it harder to track dependencies. Prefer explicitly importing only what you need.🧰 Tools
🪛 Ruff (0.8.2)
36-36:
from tests.assets.fixtures.datasets import *
used; unable to detect undefined names(F403)
41-43
: Several unused imports.
json
is used in the tests below, so it’s fine.MISSING
(line 42) andpprint
(line 43) appear unused. Consider removing them along with any other imports flagged by static analysis for clarity.

```diff
 import json
-from omegaconf import MISSING
-from pprint import pprint
```

🧰 Tools
🪛 Ruff (0.8.2)
42-42:
omegaconf.MISSING
imported but unusedRemove unused import:
omegaconf.MISSING
(F401)
43-43:
pprint.pprint
imported but unusedRemove unused import:
pprint.pprint
(F401)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
sleap_nn/config/data_config.py
(2 hunks)sleap_nn/config/training_job_config.py
(2 hunks)tests/config/test_training_job_config.py
(2 hunks)
🧰 Additional context used
🧬 Code Definitions (1)
tests/config/test_training_job_config.py (1)
sleap_nn/config/training_job_config.py (2)
load_sleap_config
(155-173)TrainingJobConfig
(42-138)
🪛 Ruff (0.8.2)
sleap_nn/config/data_config.py
8-8: omegaconf.MissingMandatoryValue
imported but unused
Remove unused import: omegaconf.MissingMandatoryValue
(F401)
tests/config/test_training_job_config.py
36-36: from tests.assets.fixtures.datasets import *
used; unable to detect undefined names
(F403)
38-38: dataclasses.asdict
imported but unused
Remove unused import: dataclasses.asdict
(F401)
42-42: omegaconf.MISSING
imported but unused
Remove unused import: omegaconf.MISSING
(F401)
43-43: pprint.pprint
imported but unused
Remove unused import: pprint.pprint
(F401)
245-245: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
296-296: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
339-339: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (10)
sleap_nn/config/training_job_config.py (2)
31-37
: No issues with new imports.
The newly introduced imports forjson
and the mapper functions (data_mapper
,model_mapper
,trainer_mapper
) look correct and necessary for the new JSON loading logic.
153-154
: No content changes in these lines.
These are simply blank lines, so there is nothing to comment on here.sleap_nn/config/data_config.py (1)
187-188
: Change toOptional[list]
for skeletons looks reasonable.
If the legacy configuration indeed returns a list of skeleton definitions, this type adjustment is appropriate.tests/config/test_training_job_config.py (7)
31-31
: No immediate issue.
Introducingload_sleap_config
into the test suite is valid given you’re adding JSON-related tests.
287-326
: Same repeated pattern of unusedValidationError as e
.
Consider applying the same fix or refactoring approach as in the previous test.🧰 Tools
🪛 Ruff (0.8.2)
296-296: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
328-369
: Same repeated pattern of unusedValidationError as e
.🧰 Tools
🪛 Ruff (0.8.2)
339-339: Local variable
e
is assigned to but never usedRemove assignment to unused variable
e
(F841)
371-404
: Same repeated pattern of unusedValidationError as e
.
406-437
: Same repeated pattern of unusedValidationError as e
.
439-471
: Same repeated pattern of unusedValidationError as e
.
473-508
: Same repeated pattern of unusedValidationError as e
.
```python
def test_load_sleap_config_from_file(training_job_config_path):
    """Test the load_sleap_config function with a sample legacy configuration from a JSON file."""
    json_file_path = training_job_config_path

    # Load the configuration using the load_sleap_config method
    try:
        config = load_sleap_config(TrainingJobConfig, json_file_path)
    except ValidationError as e:
        with open(json_file_path, "r") as f:
            old_config = json.load(f)

        # Create a temporary file to hold the modified configuration
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".json", mode="w"
        ) as temp_file:
            old_config["data"]["labels"]["training_labels"] = "notMISSING"
            old_config["data"]["labels"]["validation_labels"] = "notMISSING"

            json.dump(old_config, temp_file)
            temp_file_path = temp_file.name

        config = load_sleap_config(TrainingJobConfig, temp_file_path)
        os.remove(temp_file_path)

    # Assertions to check if the output matches expected values
    assert (
        config.data_config.train_labels_path == "notMISSING"
    )  # As per the temp JSON file
    assert (
        config.data_config.val_labels_path == "notMISSING"
    )  # As per the temp JSON file
    assert config.model_config.backbone_config.unet.filters == 8
    assert config.model_config.backbone_config.unet.max_stride == 16
    assert config.trainer_config.max_epochs == 200
    assert config.trainer_config.optimizer_name == "Adam"
    assert config.trainer_config.optimizer.lr == 0.0001
    assert config.trainer_config.trainer_devices == "auto"  # Default value
    assert config.trainer_config.trainer_accelerator == "auto"  # Default value
    assert config.trainer_config.enable_progress_bar is True  # Default value
    assert config.trainer_config.train_data_loader.batch_size == 4  # From the JSON file
    assert (
        config.trainer_config.lr_scheduler.reduce_lr_on_plateau is not None
    )  # From the JSON file
    assert (
        config.trainer_config.early_stopping.stop_training_on_plateau is True
    )  # From the JSON file
```
Refactor repeated error handling pattern.
The test tries loading the JSON config, catches ValidationError as e
, but never uses e
. Also, lines for setting "training_labels"
and "validation_labels"
to "notMISSING"
are duplicated across multiple tests. Consider factoring this into a shared helper or fixture for maintainability.
try:
config = load_sleap_config(TrainingJobConfig, json_file_path)
except ValidationError as e:
- pass # e is never used
+ logger.debug(f"Caught ValidationError: {e}")
+ # or remove `as e` if you don't need to reference it
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.8.2)
245-245: Local variable e
is assigned to but never used
Remove assignment to unused variable e
(F841)
Actionable comments posted: 3
♻️ Duplicate comments (2)
sleap_nn/config/data_config.py (2)
295-302
: 🛠️ Refactor suggestionFix scale parameter default values
The scale parameter in GeometricConfig is set to
None
values, which could conflict with validators that expect numeric values. In the GeometricConfig class definition (line 113), the default is(0.9, 1.1)
.scale=( legacy_config.get("optimization", {}) .get("augmentation_config", {}) - .get("scale_min", None), + .get("scale_min", 0.9), legacy_config.get("optimization", {}) .get("augmentation_config", {}) - .get("scale_max", None), + .get("scale_max", 1.1), ),
315-318
:⚠️ Potential issueFix inconsistent use_augmentations_train handling
There are two issues:
use_augmentations_train
is hardcoded toTrue
regardless of what's in the legacy configuration- The conditional check for setting
augmentation_config
toNone
is commented outThis could lead to inconsistencies between the augmentation configuration and whether augmentations are actually used.
) - # if legacy_config.get("use_augmentations_train", False) - # else None + if legacy_config.get("optimization", {}).get("use_augmentations", True) + else None ), - use_augmentations_train=True, + use_augmentations_train=legacy_config.get("optimization", {}).get("use_augmentations", True),
🧹 Nitpick comments (12)
sleap_nn/config/trainer_config.py (4)
313-317
: Consider simplifying the regex for capitalizing the optimizer name.The regex approach works but could be simplified using standard string methods.
```diff
-        optimizer_name=re.sub(
-            r"^[a-z]",
-            lambda x: x.group().upper(),
-            legacy_config.get("optimization", {}).get("optimizer", "adam"),
-        ),
+        optimizer_name=legacy_config.get("optimization", {}).get("optimizer", "adam").capitalize(),
```
270-311
: Consider adding a TODO comment for commented-out configurations.There are numerous commented-out configurations in this function. If these are planned for future implementation, consider adding a TODO comment explaining the plan. If they're not needed, consider removing them for better code clarity.
324-363
: Consider adding a helper function for nested dictionary access.The deeply nested
get()
calls reduce readability. Consider extracting a helper function to make the code more maintainable.

```python
def get_nested(d, *keys, default=None):
    """Get a value from a nested dictionary by a sequence of keys."""
    result = d
    for key in keys:
        if not isinstance(result, dict):
            return default
        result = result.get(key, {})
    return result if result != {} else default

# Usage example:
patience = get_nested(legacy_config, "optimization", "learning_rate_schedule", "plateau_patience", default=10)
```
349-358
: Inconsistent path access could lead to confusion.Some parameters are accessed from "learning_rate_schedule" while others might come from "lr_scheduler" (in commented code). This inconsistency could lead to confusion.
Consider documenting why different paths are used or standardize on a consistent approach if possible.
sleap_nn/config/model_config.py (4)
846-895
: Consider simplifying repetitive dictionary access patternsThe code contains many repetitive nested
.get()
calls which make the code verbose and harder to maintain. Consider creating a helper function to simplify this pattern.+def _get_nested(config, path, default=None): + """Get a value from a nested dictionary using a list of keys. + + Args: + config: Dictionary to access + path: List of keys to traverse + default: Default value if path doesn't exist + + Returns: + Value at the path or default if not found + """ + result = config + for key in path: + if not isinstance(result, dict): + return default + result = result.get(key, {}) + return result if result != {} else default def model_mapper(legacy_config: dict) -> ModelConfig: """Map the legacy model configuration to the new model configuration. Args: legacy_config: A dictionary containing the legacy model configuration. Returns: An instance of `ModelConfig` with the mapped configuration. """ + model_config = legacy_config.get("model", {}) return ModelConfig( # init_weights=legacy_config.get("init_weights", "default"), # pre_trained_weights not in old config # pretrained_backbone_weights=legacy_config.get("PretrainedEncoderConfig")?? # i think its different # pretrained_head_weights not in old config backbone_config=BackboneConfig( unet=( UNetConfig( # in_channels=legacy_config.get("backbone", {}).get("in_channels", 1), # kernel_size=legacy_config.get("backbone", {}).get("kernel_size", 3), - filters=legacy_config.get("model", {}) - .get("backbone", {}) - .get("unet", {}) - .get("filters", 32), + filters=_get_nested(model_config, ["backbone", "unet", "filters"], 32),This pattern could be applied throughout the function to make it more maintainable.
896-920
: Remove unnecessary nested parenthesesThe SingleInstanceConfig has an unnecessary level of parentheses in the conditional, which can be simplified.
head_configs=HeadConfig( single_instance=( - ( - SingleInstanceConfig( - confmaps=SingleInstanceConfMapsConfig( - part_names=legacy_config.get("model", {}) - .get("heads", {}) - .get("single_instance", {}) - .get("part_names"), - sigma=legacy_config.get("model", {}) - .get("heads", {}) - .get("single_instance", {}) - .get("sigma", 5.0), - output_stride=legacy_config.get("model", {}) - .get("heads", {}) - .get("single_instance", {}) - .get("output_stride", 1), - ) - ) - ) + SingleInstanceConfig( + confmaps=SingleInstanceConfMapsConfig( + part_names=legacy_config.get("model", {}) + .get("heads", {}) + .get("single_instance", {}) + .get("part_names"), + sigma=legacy_config.get("model", {}) + .get("heads", {}) + .get("single_instance", {}) + .get("sigma", 5.0), + output_stride=legacy_config.get("model", {}) + .get("heads", {}) + .get("single_instance", {}) + .get("output_stride", 1), + ) + ) if legacy_config.get("model", {}) .get("heads", {}) .get("single_instance") else None ),
969-1016
: The bottomup head mapping uses multi_instance pathsGood job on adding the conditional check to maintain the
@oneof
constraint for the bottomup head. I noticed that the legacy configuration key ismulti_instance
while the new configuration key isbottomup
. This seems intentional, but it would be helpful to add a comment explaining this mapping for future maintenance.bottomup=( BottomUpConfig( + # Note: bottomup head was previously called multi_instance in legacy configs confmaps=BottomUpConfMapsConfig( loss_weight=legacy_config.get("model", {}) .get("heads", {}) .get("multi_instance", {})
837-1018
: Add input validation for legacy_configThe function assumes that
legacy_config
is a valid dictionary, but doesn't validate this. Consider adding input validation to handle edge cases gracefully.def model_mapper(legacy_config: dict) -> ModelConfig: """Map the legacy model configuration to the new model configuration. Args: legacy_config: A dictionary containing the legacy model configuration. Returns: An instance of `ModelConfig` with the mapped configuration. """ + if not isinstance(legacy_config, dict): + logger.warning("Invalid legacy_config provided, using defaults") + legacy_config = {} + return ModelConfig(sleap_nn/config/data_config.py (4)
8-8
: Remove unused importThe
MissingMandatoryValue
exception is imported but never used in the code.-from omegaconf import MISSING, MissingMandatoryValue +from omegaconf import MISSING🧰 Tools
🪛 Ruff (0.8.2)
8-8:
omegaconf.MissingMandatoryValue
imported but unusedRemove unused import:
omegaconf.MissingMandatoryValue
(F401)
209-216
: Add explanation for commented-out parametersSeveral parameters are commented out without explanation. If these parameters aren't used in legacy configurations, add a comment explaining why they're commented out. If they should be included, uncomment them.
- # provider=legacy_config.get("provider", "LabelsReader"), - # user_instances_only=legacy_config.get("user_instances_only", True), - # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"), - # np_chunks_path=legacy_config.get("np_chunks_path"), - # litdata_chunks_path=legacy_config.get("litdata_chunks_path"), - # use_existing_chunks=legacy_config.get("use_existing_chunks", False), - # chunk_size=int(legacy_config.get("chunk_size", 100)), - # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True), + # These parameters are not part of the legacy SLEAP configuration format + # provider=legacy_config.get("provider", "LabelsReader"), + # user_instances_only=legacy_config.get("user_instances_only", True), + # data_pipeline_fw=legacy_config.get("data_pipeline_fw", "torch_dataset"), + # np_chunks_path=legacy_config.get("np_chunks_path"), + # litdata_chunks_path=legacy_config.get("litdata_chunks_path"), + # use_existing_chunks=legacy_config.get("use_existing_chunks", False), + # chunk_size=int(legacy_config.get("chunk_size", 100)), + # delete_chunks_after_training=legacy_config.get("delete_chunks_after_training", True),
303-312
: Explain commented-out geometric augmentation parametersSeveral geometric augmentation parameters are commented out without explanation. If these parameters aren't available in legacy configurations, add a comment explaining why they're commented out. If they should be included, uncomment them.
- # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2), - # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2), - # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0), - # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001), - # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01), - # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0), - # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0), - # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0), - # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]), - # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0), + # These geometric augmentation parameters are not available in legacy SLEAP configuration format + # translate_width=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_width", 0.2), + # translate_height=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("translate_height", 0.2), + # affine_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("affine_p", 0.0), + # erase_scale_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_min", 0.0001), + # erase_scale_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_scale_max", 0.01), + # erase_ratio_min=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_min", 1.0), + # erase_ratio_max=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_ratio_max", 1.0), + # erase_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("erase_p", 0.0), + # mixup_lambda=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_lambda", [0.01, 0.05]), + # mixup_p=legacy_config.get("optimization", {}).get("augmentation_config", {}).get("mixup_p", 0.0),
318-318
: Add unit tests for data_mapper functionThe
data_mapper
function is complex and handles critical configuration mapping, but there don't appear to be tests for it. Consider adding unit tests to verify it handles various legacy configuration formats correctly.Would you like me to generate example unit tests for the
data_mapper
function?
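A minimal sketch of such a test, using only the legacy keys already exercised elsewhere in this PR (the expected values are assumptions to check against the mapper's actual defaults):

```python
from sleap_nn.config.data_config import data_mapper

def test_data_mapper_maps_label_paths():
    """Sketch: a tiny legacy config should map its label paths onto the new DataConfig."""
    legacy_config = {
        "data": {
            "labels": {
                "training_labels": "train.slp",
                "validation_labels": "val.slp",
            }
        }
    }
    config = data_mapper(legacy_config)
    assert config.train_labels_path == "train.slp"
    assert config.val_labels_path == "val.slp"
```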
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
sleap_nn/config/data_config.py
(2 hunks)sleap_nn/config/model_config.py
(1 hunks)sleap_nn/config/trainer_config.py
(2 hunks)sleap_nn/config/training_job_config.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- sleap_nn/config/training_job_config.py
🧰 Additional context used
🪛 Ruff (0.8.2)
sleap_nn/config/data_config.py
8-8: omegaconf.MissingMandatoryValue
imported but unused
Remove unused import: omegaconf.MissingMandatoryValue
(F401)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: Tests (macos-14, Python 3.9)
- GitHub Check: Tests (windows-latest, Python 3.9)
- GitHub Check: Tests (ubuntu-latest, Python 3.9)
- GitHub Check: Lint
🔇 Additional comments (5)
sleap_nn/config/trainer_config.py (4)
10-10
: LGTM on the new import.The
re
module import is required for the capitalize operation used in thetrainer_mapper
function.
237-245
: Function signature and docstring look good.The docstring clearly explains the purpose, parameters, and return value of the function.
247-264
: LGTM on the DataLoader configurations.The train and validation DataLoader configurations correctly set the batch_size and shuffle parameters from the legacy configuration.
364-378
: LGTM on the early stopping configuration.The early stopping configuration is correctly conditioned on the presence of learning_rate_schedule and properly maps all required parameters.
sleap_nn/config/model_config.py (1)
837-845
: Good documentation and parameter validationThe function is well-documented with clear docstring and parameter descriptions. The mapping functionality should help with backward compatibility for legacy configurations.
.get("backbone", {}) | ||
.get("unet", {}) | ||
.get("stacks", 1), | ||
# convs_per_block=2, |
💡 Verification agent
🧩 Analysis chain
Commented convs_per_block parameter
The convs_per_block
parameter is commented out. Should this be mapped from the legacy configuration or always set to a default value? If it should be mapped, uncomment and provide the correct mapping.
🏁 Script executed:
#!/bin/bash
# Check if convs_per_block was a configurable parameter in legacy configs
rg -i "convs_per_block" --type py | grep -v "model_config.py"
Length of output: 4207
Action Required - Map the Legacy convs_per_block Parameter Correctly
Based on the legacy configuration and tests, the convs_per_block
parameter is indeed expected to be configurable (with a common default of 2). Please update the mapping in sleap_nn/config/model_config.py
so that if a user provides convs_per_block
in their configuration it is picked up; otherwise, it should default to 2. For example, replace the commented line with something like:
convs_per_block = config.get("convs_per_block", 2)
- Location:
sleap_nn/config/model_config.py
(around line 884) - Action: Uncomment the parameter and map it appropriately from the legacy configuration.
else None | ||
), | ||
centroid=( | ||
CentroidConfig( | ||
confmaps=CentroidConfMapsConfig( | ||
anchor_part=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centroid", {}) | ||
.get("anchor_part"), | ||
sigma=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centroid", {}) | ||
.get("sigma", 5.0), | ||
output_stride=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centroid", {}) | ||
.get("output_stride", 1), | ||
) | ||
) | ||
if legacy_config.get("model", {}).get("heads", {}).get("centroid") | ||
else None | ||
), | ||
centered_instance=( | ||
CenteredInstanceConfig( | ||
confmaps=CenteredInstanceConfMapsConfig( | ||
anchor_part=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centered_instance", {}) | ||
.get("anchor_part"), | ||
sigma=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centered_instance", {}) | ||
.get("sigma", 5.0), | ||
output_stride=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centered_instance", {}) | ||
.get("output_stride", 1), | ||
part_names=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centered_instance", {}) | ||
.get("part_names", None), | ||
) | ||
) | ||
if legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("centered_instance") | ||
else None | ||
), | ||
bottomup=( | ||
BottomUpConfig( | ||
confmaps=BottomUpConfMapsConfig( | ||
loss_weight=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("confmaps", {}) | ||
.get("loss_weight", None), | ||
sigma=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("confmaps", {}) | ||
.get("sigma", 5.0), | ||
output_stride=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("confmaps", {}) | ||
.get("output_stride", 1), | ||
part_names=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("confmaps", {}) | ||
.get("part_names", None), | ||
), | ||
pafs=PAFConfig( | ||
edges=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("pafs", {}) | ||
.get("edges", None), | ||
sigma=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("pafs", {}) | ||
.get("sigma", 15.0), | ||
output_stride=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("pafs", {}) | ||
.get("output_stride", 1), | ||
loss_weight=legacy_config.get("model", {}) | ||
.get("heads", {}) | ||
.get("multi_instance", {}) | ||
.get("pafs", {}) | ||
.get("loss_weight", None), | ||
), | ||
) | ||
if legacy_config.get("model", {}).get("heads", {}).get("multi_instance") | ||
else None | ||
), | ||
), | ||
) |
💡 Verification agent
🧩 Analysis chain
Consider adding support for other backbone types
Currently, the function only handles UNet backbones but doesn't have logic for other backbone types like ConvNext or SwinT that might have been in legacy configurations. Consider adding support for these if they were present in legacy configurations.
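A rough sketch of what such a dispatch could look like inside model_mapper is shown below; the import path follows this PR, but the convnext/swint legacy keys and the default-constructed configs are assumptions to verify against real legacy files:

```python
from sleap_nn.config.model_config import (
    BackboneConfig,
    UNetConfig,
    ConvNextConfig,  # assumed to exist alongside UNetConfig
    SwinTConfig,     # assumed to exist alongside UNetConfig
)

def map_backbone(legacy_config: dict) -> BackboneConfig:
    """Sketch: pick the backbone sub-config based on which legacy key is present."""
    backbone = legacy_config.get("model", {}).get("backbone", {})
    if backbone.get("unet") is not None:
        return BackboneConfig(unet=UNetConfig(filters=backbone["unet"].get("filters", 32)))
    if backbone.get("convnext") is not None:
        return BackboneConfig(convnext=ConvNextConfig())
    if backbone.get("swint") is not None:
        return BackboneConfig(swint=SwinTConfig())
    return BackboneConfig()
```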
To verify whether other backbones were used in legacy configurations, you could check for their presence in the code:
🏁 Script executed:
#!/bin/bash
# Search for any mentions of convnext or swint in legacy config handling
rg -i "backbone.*conv|conv.*backbone|backbone.*swin|swin.*backbone" --type py
Length of output: 6718
Enhance the model_mapper
function to support additional backbone types
The current mapping only handles the UNet backbone configuration. However, our search results show that the codebase (including tests and training/inference modules) references other backbone types—specifically, convnext
and swint
. To ensure consistency with how these backbones are used throughout the project, please consider adding corresponding mapping logic for them in the BackboneConfig
.
- In
sleap_nn/config/model_config.py
(lines 837-1018), the mapper currently checks for a legacyunet
configuration but leaves placeholders (and comments) forconvnext
andswint
. - The grep output confirms that these backbones are used in multiple files (e.g.,
sleap_nn/train.py
,sleap_nn/inference/predictors.py
, and various tests). - Consider introducing conditional logic that inspects the legacy configuration for keys corresponding to
convnext
andswint
and maps them to the appropriate configuration classes (e.g.,ConvNextConfig
andSwinTConfig
).
```diff
@@ -184,4 +184,137 @@ class DataConfig:
     preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
     use_augmentations_train: bool = False
     augmentation_config: Optional[AugmentationConfig] = None
-    skeletons: Optional[dict] = None
+    skeletons: Optional[list] = None
```
💡 Verification agent
🧩 Analysis chain
Verify type change from dict to list for skeletons
The type of skeletons
attribute has been changed from Optional[dict]
to Optional[list]
, but the function still retrieves it as if it were a dictionary structure from the legacy configuration. Ensure this change is intentional and compatible with how skeletons are used elsewhere in the codebase.
Run the following to check how skeletons are used elsewhere:
🏁 Script executed:
#!/bin/bash
# Check the structure of skeletons and where it's used in the codebase
rg -A 5 -B 5 "skeletons" --type py | grep -v "Optional\[list\]" | grep -v "Optional\[dict\]"
Length of output: 65229
Mismatch between declared type and legacy usage for skeletons
Our verification shows that many parts of the codebase—including tests such as in tests/config/test_data_config.py
and functions like get_skeleton_from_config
in sleap_nn/inference/utils.py
—continue to treat the skeleton configuration as a dictionary (e.g., by accessing keys or iterating over items). The type annotation change in sleap_nn/config/data_config.py
(line 187):
skeletons: Optional[list] = None
introduces an inconsistency. Please confirm whether the intent is to switch to a list-based representation. If so, the legacy retrieval logic and downstream usage must be updated accordingly. If not, revert the type annotation to reflect the expected dictionary format.
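If the list representation is intended, one low-risk option is an explicit conversion step when mapping legacy configs. A sketch, assuming old configs store skeletons as a name-keyed dict (this layout should be verified against actual legacy files):

```python
def legacy_skeletons_to_list(legacy_config: dict) -> list:
    """Sketch: normalize a name-keyed skeletons dict into a list of entries."""
    skeletons = legacy_config.get("data", {}).get("labels", {}).get("skeletons") or {}
    if isinstance(skeletons, list):
        return skeletons  # already list-shaped
    return [{"name": name, **spec} for name, spec in skeletons.items()]
```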
Summary by CodeRabbit

- trainer_mapper function to ensure default values are correctly applied for unspecified parameters.
- trainer_mapper, model_mapper, and data_mapper functions, ensuring correct configuration outputs.